import cv2
import numpy as np
import matplotlib.pyplot as plt


def fourier_denoise(image, cutoff_frequency):
    # 将图像转换为频率域
    f = np.fft.fft2(image)
    fshift = np.fft.fftshift(f)

    # 创建一个与图像大小相同的掩码，中心为低通滤波区域
    rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2
    mask = np.zeros((rows, cols), np.uint8)
    mask[
        crow - cutoff_frequency : crow + cutoff_frequency,
        ccol - cutoff_frequency : ccol + cutoff_frequency,
    ] = 1

    # 应用低通滤波器
    fshift_filtered = fshift * mask

    # 将过滤后的频率域图像转换回空间域
    f_ishift = np.fft.ifftshift(fshift_filtered)
    img_back = np.fft.ifft2(f_ishift)
    img_back = np.abs(img_back)

    return img_back


def denoise_image(image_path, cutoff_frequency):
    # 读取图像
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)

    # 分别对每个通道进行去噪
    denoised_channels = []
    for i in range(3):
        channel = img[:, :, i]
        denoised_channel = fourier_denoise(channel, cutoff_frequency)
        denoised_channels.append(denoised_channel)

    # 合并去噪后的通道
    denoised_img = np.stack(denoised_channels, axis=2)

    # 将结果转换为合适的数据类型和范围
    denoised_img = np.clip(denoised_img, 0, 255).astype(np.uint8)

    return denoised_img


def display_images(original_img, denoised_img):
    # 显示原图和处理后的图像
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    ax = axes.ravel()
    ax[0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
    ax[0].set_title("Original Image")
    ax[0].axis("off")
    ax[1].imshow(cv2.cvtColor(denoised_img, cv2.COLOR_BGR2RGB))
    ax[1].set_title("Denoised Image")
    ax[1].axis("off")
    plt.tight_layout()
    plt.show()


# 使用示例
# 示例：对图像进行降噪
# 这里需要一张示例图片的路径和截止频率
# 截止频率越低，去噪效果越明显
image_path = "image.png"  # 替换为您的图片路径
cutoff_frequency = 150  # 根据需要调整截止频率
denoised_img = denoise_image(image_path, cutoff_frequency)
display_images(cv2.imread(image_path), denoised_img)
